{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "777fc40d",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\" width=\"90px\">\n",
    "\n",
    "# PySpark Huggingface Inferencing\n",
    "### Sentence Transformers with PyTorch\n",
    "\n",
    "In this notebook, we demonstrate distributed inference with the Huggingface SentenceTransformer library for sentence embedding.  \n",
    "From: https://huggingface.co/sentence-transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c5f0d0a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n",
    "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n",
    "import os\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "731faab7-a700-46f8-bba5-1c8764e5eacb",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\", device=device)\n",
    "\n",
    "sentence = ['This framework generates embeddings for each input sentence']\n",
    "embedding = model.encode(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "96eea5ca-3cf7-46e3-b40c-598538112d24",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-0.17621444  0.1206013  -0.29362372 -0.22985819 -0.08229247  0.2377093\n",
      "  0.33998525 -0.7809643   0.11812777  0.16337365 -0.13771524  0.24028276\n",
      "  0.4251256   0.17241786  0.10527937  0.5181643   0.062222    0.39928585\n",
      " -0.18165241 -0.58557856]\n"
     ]
    }
   ],
   "source": [
    "print(embedding[0][:20])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "546eabe0",
   "metadata": {},
   "source": [
    "## PySpark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dbda3e66-005a-4ad0-8017-c1cc7cbf0058",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql.types import *\n",
    "from pyspark import SparkConf\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.sql.functions import pandas_udf, col, struct\n",
    "from pyspark.ml.functions import predict_batch_udf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b525c5c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "import datasets\n",
    "from datasets import load_dataset\n",
    "datasets.disable_progress_bars()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58e7c1bc",
   "metadata": {},
   "source": [
    "Check the cluster environment to handle any platform-specific Spark configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5a013217",
   "metadata": {},
   "outputs": [],
   "source": [
    "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
    "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
    "on_standalone = not (on_databricks or on_dataproc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad3c003d",
   "metadata": {},
   "source": [
    "#### Create Spark Session\n",
    "\n",
    "For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \n",
    "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "23ec67ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/04 13:40:01 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
      "25/02/04 13:40:01 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "25/02/04 13:40:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    }
   ],
   "source": [
    "conf = SparkConf()\n",
    "\n",
    "if 'spark' not in globals():\n",
    "    if on_standalone:\n",
    "        import socket\n",
    "        conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
    "        hostname = socket.gethostname()\n",
    "        conf.setMaster(f\"spark://{hostname}:7077\")\n",
    "        conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
    "        conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
    "\n",
    "    conf.set(\"spark.executor.cores\", \"8\")\n",
    "    conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
    "    conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
    "    conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
    "    conf.set(\"spark.python.worker.reuse\", \"true\")\n",
    "\n",
    "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n",
    "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
    "sc = spark.sparkContext"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cfd1394",
   "metadata": {},
   "source": [
    "Load the IMBD Movie Reviews dataset from Huggingface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9bc1edb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"imdb\", split=\"test\")\n",
    "dataset = dataset.to_pandas().drop(columns=\"label\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59c71bff",
   "metadata": {},
   "source": [
    "#### Create PySpark DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "836e5f84-12c6-4c95-838e-53de7e46a20b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "StructType([StructField('text', StringType(), True)])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = spark.createDataFrame(dataset).repartition(8)\n",
    "df.schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "36703d23-37a3-40df-b09a-c68206d285b6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "25000"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1f122ae3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/04 13:40:08 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..<br /><br />I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.take(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "14fd59fb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/04 13:40:08 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
     ]
    }
   ],
   "source": [
    "data_path = \"spark-dl-datasets/imdb_test\"\n",
    "if on_databricks:\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
    "    data_path = \"dbfs:/FileStore/\" + data_path\n",
    "\n",
    "df.write.mode(\"overwrite\").parquet(data_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bb083ec",
   "metadata": {},
   "source": [
    "#### Load and preprocess DataFrame\n",
    "\n",
    "Define our preprocess function. We'll take the first sentence from each sample as our input for translation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2510bdd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "@pandas_udf(\"string\")\n",
    "def preprocess(text: pd.Series) -> pd.Series:\n",
    "    return pd.Series([s.split(\".\")[0] for s in text])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5bb28548",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------------------------------------------------------------------------------------------+\n",
      "|                                                                                               input|\n",
      "+----------------------------------------------------------------------------------------------------+\n",
      "|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\n",
      "|                          There were two things I hated about WASTED : The directing and the script |\n",
      "|                                I'm rather surprised that anybody found this film touching or moving|\n",
      "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\n",
      "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\n",
      "|                                                                     This movie has been done before|\n",
      "|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\n",
      "|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\n",
      "|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\n",
      "|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did such talented actors get involved in such mindles...|\n",
      "|                              There is not one character on this sitcom with any redeeming qualities|\n",
      "|                                  Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\n",
      "|                                 My wife rented this movie and then conveniently never got to see it|\n",
      "|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\n",
      "|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\n",
      "|                           you will likely be sorely disappointed by this sequel that's not a sequel|\n",
      "|                          If I was British, I would be embarrassed by this portrayal of incompetence|\n",
      "|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\n",
      "|                        This show is like watching someone who is in training to someday host a show|\n",
      "|                                                                                                Sigh|\n",
      "+----------------------------------------------------------------------------------------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Limit to N rows, since this can be slow\n",
    "df = spark.read.parquet(data_path).limit(256).repartition(8)\n",
    "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()\n",
    "df.show(truncate=100)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "014eae88",
   "metadata": {},
   "source": [
    "## Inference using Spark DL API\n",
    "\n",
    "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
    "\n",
    "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n",
    "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f780c026-0f3f-4aea-8b61-5b3dbae83fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_batch_fn():\n",
    "    import torch\n",
    "    from sentence_transformers import SentenceTransformer\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\", device=device)\n",
    "    def predict(inputs):\n",
    "        return model.encode(inputs.tolist())\n",
    "    return predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "f5c88ddc-ca19-4430-8b0e-b9fae143b237",
   "metadata": {},
   "outputs": [],
   "source": [
    "encode = predict_batch_udf(predict_batch_fn,\n",
    "                           return_type=ArrayType(FloatType()),\n",
    "                           batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "85344c22-4a4d-4cb0-8771-5836ae2794db",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Stage 18:=====================>                                    (3 + 5) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 10.6 ms, sys: 4.83 ms, total: 15.4 ms\n",
      "Wall time: 4.23 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# first pass caches model/fn\n",
    "embeddings = df.withColumn(\"embedding\", encode(struct(\"input\")))\n",
    "results = embeddings.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c23bb885-6ab0-4471-943d-4c10414100fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6.7 ms, sys: 2.44 ms, total: 9.15 ms\n",
      "Wall time: 163 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "embeddings = df.withColumn(\"embedding\", encode(\"input\"))\n",
    "results = embeddings.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "93bc6da3-d853-4233-b805-cb4a46f4f9b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 5.37 ms, sys: 2.73 ms, total: 8.1 ms\n",
      "Wall time: 232 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "embeddings = df.withColumn(\"embedding\", encode(col(\"input\")))\n",
    "results = embeddings.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "2073616f-7151-4760-92f2-441dd0bfe9fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------------------------------------+--------------------------------------------------+\n",
      "|                                             input|                                         embedding|\n",
      "+--------------------------------------------------+--------------------------------------------------+\n",
      "|Doesn't anyone bother to check where this kind ...|[0.118947476, -0.053823642, -0.29726124, 0.0720...|\n",
      "|There were two things I hated about WASTED : Th...|[0.18953452, 0.11079162, 0.07503566, 0.01050696...|\n",
      "|I'm rather surprised that anybody found this fi...|[-0.0010759671, -0.14203517, -0.06649738, 0.129...|\n",
      "|Cultural Vandalism Is the new Hallmark producti...|[0.34815887, -0.2966917, -0.10905265, 0.1051652...|\n",
      "|I was at Wrestlemania VI in Toronto as a 10 yea...|[0.45902696, 0.019472413, 0.28720972, -0.070724...|\n",
      "|                   This movie has been done before|[-0.062292397, -0.025909504, -0.031942524, 0.01...|\n",
      "|[ as a new resolution for this year 2005, i dec...|[0.3469342, -0.14378615, 0.30223376, -0.1102267...|\n",
      "|This movie is over hyped!! I am sad to say that...|[0.13230576, -0.06588756, 0.0472389, 0.08353163...|\n",
      "|This show had a promising start as sort of the ...|[-0.19361982, -0.14412567, 0.15149693, -0.17715...|\n",
      "|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did...|[-0.048036292, 0.050720096, -0.04668727, -0.316...|\n",
      "|There is not one character on this sitcom with ...|[0.13720773, -0.5963504, 0.30331734, -0.3830607...|\n",
      "|Tommy Lee Jones was the best Woodroe and no one...|[-0.20960267, -0.15760122, -0.30596405, -0.5181...|\n",
      "|My wife rented this movie and then conveniently...|[0.46534792, -0.40655977, 0.054217298, -0.03414...|\n",
      "|This is one of those star-filled over-the-top c...|[0.14433198, -0.016140658, 0.3775344, 0.0659043...|\n",
      "|This excruciatingly boring and unfunny movie ma...|[0.056464806, 0.01144963, -0.51797307, 0.089813...|\n",
      "|you will likely be sorely disappointed by this ...|[-0.44146675, -0.17866582, 0.49889183, -0.26819...|\n",
      "|If I was British, I would be embarrassed by thi...|[0.1191261, -0.15379854, 0.17487673, -0.5123498...|\n",
      "|One of those movies in which there are no big t...|[-0.016174048, -0.5558219, -0.024818476, 0.1543...|\n",
      "|This show is like watching someone who is in tr...|[0.033776704, -0.6682203, 0.30547586, -0.581407...|\n",
      "|                                              Sigh|[-0.119870394, 0.40893683, 0.4174831, -0.010004...|\n",
      "+--------------------------------------------------+--------------------------------------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "embeddings.show(truncate=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c9c6535",
   "metadata": {},
   "source": [
    "## Using Triton Inference Server\n",
    "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.  \n",
    "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \n",
    "\n",
    "The process looks like this:\n",
    "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
    "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
    "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
    "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
    "\n",
    "<img src=\"../images/spark-server.png\" alt=\"drawing\" width=\"700\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "772e337e-1098-4c7b-ba81-8cb221a518e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "759385ac",
   "metadata": {},
   "source": [
    "Import the helper class from server_utils.py:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "485fb0de",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.addPyFile(\"server_utils.py\")\n",
    "\n",
    "from server_utils import TritonServerManager"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ece5c38a",
   "metadata": {},
   "source": [
    "Define the Triton Server function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "69d0c93a-bb0b-46c5-9d28-7b08a2e70964",
   "metadata": {},
   "outputs": [],
   "source": [
    "def triton_server(ports):\n",
    "    import time\n",
    "    import signal\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    from sentence_transformers import SentenceTransformer\n",
    "    from pytriton.decorators import batch\n",
    "    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
    "    from pytriton.triton import Triton, TritonConfig\n",
    "    from pyspark import TaskContext\n",
    "\n",
    "    print(f\"SERVER: Initializing sentence transformer on worker {TaskContext.get().partitionId()}.\")\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\", device=device)\n",
    "    print(f\"SERVER: Using {device} device.\")\n",
    "\n",
    "    @batch\n",
    "    def _infer_fn(**inputs):\n",
    "        sentences = np.squeeze(inputs[\"text\"])\n",
    "        print(f\"SERVER: Received batch of size {len(sentences)}\")\n",
    "        decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n",
    "        embeddings = model.encode(decoded_sentences)\n",
    "        return {\n",
    "            \"embeddings\": embeddings,\n",
    "        }\n",
    "\n",
    "    workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
    "    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
    "    with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
    "        triton.bind(\n",
    "            model_name=\"SentenceTransformer\",\n",
    "            infer_func=_infer_fn,\n",
    "            inputs=[\n",
    "                Tensor(name=\"text\", dtype=object, shape=(-1,)),\n",
    "            ],\n",
    "            outputs=[\n",
    "                Tensor(name=\"embeddings\", dtype=np.float32, shape=(-1,)),\n",
    "            ],\n",
    "            config=ModelConfig(\n",
    "                max_batch_size=64,\n",
    "                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\n",
    "            ),\n",
    "            strict=True,\n",
    "        )\n",
    "\n",
    "        def _stop_triton(signum, frame):\n",
    "            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n",
    "            print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
    "            triton.stop()\n",
    "\n",
    "        signal.signal(signal.SIGTERM, _stop_triton)\n",
    "\n",
    "        print(\"SERVER: Serving inference\")\n",
    "        triton.serve()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79532110",
   "metadata": {},
   "source": [
    "#### Start Triton servers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b0371c8",
   "metadata": {},
   "source": [
    "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n",
    "- Find available ports for HTTP/gRPC/metrics\n",
    "- Deploy a server on each node via stage-level scheduling\n",
    "- Gracefully shutdown servers across nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e66e8927",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"SentenceTransformer\"\n",
    "server_manager = TritonServerManager(model_name=model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "040df0dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n",
      "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Returns {'hostname', (server_pid, [http_port, grpc_port, metrics_port])}\n",
    "server_manager.start_servers(triton_server)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fd19fae",
   "metadata": {},
   "source": [
    "#### Define client function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddeadc74",
   "metadata": {},
   "source": [
    "Get the hostname -> url mapping from the server manager:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c42d1578",
   "metadata": {},
   "outputs": [],
   "source": [
    "host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "807dbc45",
   "metadata": {},
   "outputs": [],
   "source": [
    "def triton_fn(model_name, host_to_url):\n",
    "    import socket\n",
    "    import numpy as np\n",
    "    from pytriton.client import ModelClient\n",
    "\n",
    "    url = host_to_url[socket.gethostname()]\n",
    "    print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
    "\n",
    "    def infer_batch(inputs):\n",
    "        with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
    "            flattened = np.squeeze(inputs).tolist()\n",
    "            # Encode batch\n",
    "            encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n",
    "            encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n",
    "            # Run inference\n",
    "            result_data = client.infer_batch(encoded_batch_np)\n",
    "            return result_data[\"embeddings\"]\n",
    "        \n",
    "    return infer_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "9c712b8f-6eb4-4fb8-9f0a-04feef847fea",
   "metadata": {},
   "outputs": [],
   "source": [
    "encode = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n",
    "                           return_type=ArrayType(FloatType()),\n",
    "                           input_tensor_shapes=[[1]],\n",
    "                           batch_size=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af174106",
   "metadata": {},
   "source": [
    "#### Load and preprocess DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "2969d502-e97b-49d6-bf80-7d177ae867cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "@pandas_udf(\"string\")\n",
    "def preprocess(text: pd.Series) -> pd.Series:\n",
    "    return pd.Series([s.split(\".\")[0] for s in text])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "c8f1e6d6-6519-49e7-8465-4419547633b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/04 13:40:22 WARN CacheManager: Asked to cache already cached data.\n"
     ]
    }
   ],
   "source": [
    "df = spark.read.parquet(data_path).limit(256).repartition(8)\n",
    "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf0ee731",
   "metadata": {},
   "source": [
    "#### Run Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "934c1a1f-b126-45b0-9c15-265236820ad3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 5.59 ms, sys: 5.1 ms, total: 10.7 ms\n",
      "Wall time: 605 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# first pass caches model/fn\n",
    "embeddings = df.withColumn(\"embedding\", encode(struct(\"input\")))\n",
    "results = embeddings.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "f84cd3f6-b6a8-4142-859a-91f3c183457b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.57 ms, sys: 4.36 ms, total: 6.93 ms\n",
      "Wall time: 161 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "embeddings = df.withColumn(\"embedding\", encode(\"input\"))\n",
    "results = embeddings.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "921a4c01-e296-4406-be90-86f20c8c582d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 7.06 ms, sys: 605 μs, total: 7.67 ms\n",
      "Wall time: 191 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "embeddings = df.withColumn(\"embedding\", encode(col(\"input\")))\n",
    "results = embeddings.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "9f67584e-9c4e-474f-b6ea-7811b14d116e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------------------------------------+--------------------------------------------------+\n",
      "|                                             input|                                         embedding|\n",
      "+--------------------------------------------------+--------------------------------------------------+\n",
      "|Doesn't anyone bother to check where this kind ...|[0.118947476, -0.053823642, -0.29726124, 0.0720...|\n",
      "|There were two things I hated about WASTED : Th...|[0.18953452, 0.11079162, 0.07503566, 0.01050696...|\n",
      "|I'm rather surprised that anybody found this fi...|[-0.0010759671, -0.14203517, -0.06649738, 0.129...|\n",
      "|Cultural Vandalism Is the new Hallmark producti...|[0.34815887, -0.2966917, -0.10905265, 0.1051652...|\n",
      "|I was at Wrestlemania VI in Toronto as a 10 yea...|[0.45902696, 0.019472413, 0.28720972, -0.070724...|\n",
      "|                   This movie has been done before|[-0.062292397, -0.025909504, -0.031942524, 0.01...|\n",
      "|[ as a new resolution for this year 2005, i dec...|[0.3469342, -0.14378615, 0.30223376, -0.1102267...|\n",
      "|This movie is over hyped!! I am sad to say that...|[0.13230576, -0.06588756, 0.0472389, 0.08353163...|\n",
      "|This show had a promising start as sort of the ...|[-0.19361982, -0.14412567, 0.15149693, -0.17715...|\n",
      "|MINOR PLOT SPOILERS AHEAD!!!<br /><br />How did...|[-0.048036292, 0.050720096, -0.04668727, -0.316...|\n",
      "|There is not one character on this sitcom with ...|[0.13720773, -0.5963504, 0.30331734, -0.3830607...|\n",
      "|Tommy Lee Jones was the best Woodroe and no one...|[-0.20960267, -0.15760122, -0.30596405, -0.5181...|\n",
      "|My wife rented this movie and then conveniently...|[0.46534792, -0.40655977, 0.054217298, -0.03414...|\n",
      "|This is one of those star-filled over-the-top c...|[0.14433198, -0.016140658, 0.3775344, 0.0659043...|\n",
      "|This excruciatingly boring and unfunny movie ma...|[0.056464806, 0.01144963, -0.51797307, 0.089813...|\n",
      "|you will likely be sorely disappointed by this ...|[-0.44146675, -0.17866582, 0.49889183, -0.26819...|\n",
      "|If I was British, I would be embarrassed by thi...|[0.1191261, -0.15379854, 0.17487673, -0.5123498...|\n",
      "|One of those movies in which there are no big t...|[-0.016174048, -0.5558219, -0.024818476, 0.1543...|\n",
      "|This show is like watching someone who is in tr...|[0.033776704, -0.6682203, 0.30547586, -0.581407...|\n",
      "|                                              Sigh|[-0.119870394, 0.40893683, 0.4174831, -0.010004...|\n",
      "+--------------------------------------------------+--------------------------------------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "embeddings.show(truncate=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e3b0077c-785f-41af-9fa9-812e7fb63810",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Stop Triton Server on each executor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "ef780e30",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-04 13:40:23,196 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n",
      "2025-02-04 13:40:28,390 - INFO - Sucessfully stopped 1 servers.                 \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[True]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "server_manager.stop_servers()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "e82b9518-da7b-4ebc-8990-c8ab909bec18",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not on_databricks: # on databricks, spark.stop() puts the cluster in a bad state\n",
    "    spark.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33a60f2d-295a-4270-a2fd-16559962edda",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spark-dl-torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
